{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.5 FETI-DP in NGSolve IV: Inexact FETI-DP\n",
    "\n",
    "This time, we will, instead of going to the schur-complement for $\\lambda$,\n",
    "directly iterate on the equation\n",
    "$$\n",
    "\\left(\n",
    "\\begin{matrix}\n",
    "A_{\\scriptscriptstyle DP} & B^T \\\\\n",
    "B & 0\n",
    "\\end{matrix}\n",
    "\\right)\n",
    "\\left(\n",
    "\\begin{matrix}\n",
    "u\\\\\n",
    "\\lambda\n",
    "\\end{matrix}\n",
    "\\right)\n",
    "=\n",
    "\\left(\n",
    "\\begin{matrix}\n",
    "f \\\\\n",
    "0\n",
    "\\end{matrix}\n",
    "\\right)\n",
    "$$\n",
    "\n",
    "The preconditioner will be \n",
    "$$\n",
    "\\widehat{M}^{-1} = \n",
    "\\left(\n",
    "\\begin{matrix}\n",
    "A_{\\scriptscriptstyle DP}^{-1} & 0 \\\\\n",
    "-M_s^{-1} B A_{\\scriptscriptstyle DP}^{-1} & M_s^{-1} \\\\\n",
    "\\end{matrix}\n",
    "\\right)\n",
    "$$\n",
    "\n",
    "As $\\widehat{M}^{-1}$ is not symmetric, we will use GMRES.\n",
    "\n",
    "For setting up the preconditioner, we only need pieces we already have.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_procs = '100'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from usrmeeting_jupyterstuff import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "stop_cluster()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Waiting for connection file: ~/.ipython/profile_ngsolve/security/ipcontroller-kogler-client.json\n",
      "connecting ... try:6 succeeded!"
     ]
    }
   ],
   "source": [
    "start_cluster(num_procs)\n",
    "connect_cluster()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%px\n",
    "from ngsolve import *\n",
    "import netgen.meshing as ngmeshing\n",
    "from ngsolve.la import ParallelMatrix, FETI_Jump, SparseMatrixd, ParallelDofs, BlockMatrix\n",
    "from dd_toolbox import FindFEV, DPSpace_Inverse, ScaledMat, LinMat\n",
    "\n",
    "def load_mesh(nref=0):\n",
    "    ngmesh = ngmeshing.Mesh(dim=3)\n",
    "    ngmesh.Load('cube.vol')\n",
    "    for l in range(nref):\n",
    "        ngmesh.Refine()\n",
    "    return Mesh(ngmesh)\n",
    "\n",
    "def setup_space(mesh, order=1):\n",
    "    comm = MPI_Init()\n",
    "    fes = H1(mesh, order=order, dirichlet='right|top')\n",
    "    a = BilinearForm(fes)\n",
    "    u,v = fes.TnT()\n",
    "    a += SymbolicBFI(grad(u)*grad(v))\n",
    "    a.Assemble()\n",
    "    f = LinearForm(fes)\n",
    "    f += SymbolicLFI(x*y*v)\n",
    "    f.Assemble()\n",
    "    avg_dof = comm.Sum(fes.ndof) / comm.size\n",
    "    if comm.rank==0:\n",
    "        print('global,  ndof =', fes.ndofglobal, ', lodofs =', fes.lospace.ndofglobal)\n",
    "        print('avg DOFs per core: ', avg_dof)\n",
    "    return [fes, a, f]\n",
    "\n",
    "def setup_FETIDP(fes, a):\n",
    "    faces, edges, vertices = FindFEV(mesh.dim, mesh.nv, \\\n",
    "                                     fes.ParallelDofs(), fes.FreeDofs())\n",
    "    primal_dofs = BitArray([ v in set(vertices) for v in range(fes.ndof) ]) & fes.FreeDofs() \n",
    "    dp_pardofs = fes.ParallelDofs().SubSet(primal_dofs)\n",
    "    ar = [(num_e[0],d,1.0) for num_e in enumerate(edges) for d in num_e[1] ]\n",
    "    rows, cols, vals = [list(x) for x in zip(*ar)] if len(ar) else [[],[],[]]\n",
    "    B_p = SparseMatrixd.CreateFromCOO(rows, cols, vals, len(edges), fes.ndof)\n",
    "    edist_procs = [sorted(set.intersection(*[set(fes.ParallelDofs().Dof2Proc(v)) for v in edge])) for edge in edges]\n",
    "    eavg_pardofs = ParallelDofs(edist_procs, comm)\n",
    "    nprim = comm.Sum(sum([1 for k in range(fes.ndof) if primal_dofs[k] and comm.rank<fes.ParallelDofs().Dof2Proc(k)[0] ]))\n",
    "    if comm.rank==0:\n",
    "        print('# of global primal dofs: ', nprim)  \n",
    "    A_dp = ParallelMatrix(a.mat.local_mat, dp_pardofs)\n",
    "    dual_pardofs = fes.ParallelDofs().SubSet(BitArray(~primal_dofs & fes.FreeDofs()))\n",
    "    B = FETI_Jump(dual_pardofs, u_pardofs=dp_pardofs)\n",
    "    if comm.rank==0:\n",
    "        print('# of global multipliers = :', B.col_pardofs.ndofglobal)\n",
    "    A_dp_inv = DPSpace_Inverse(mat=a.mat, freedofs=fes.FreeDofs(), \\\n",
    "                               c_points=primal_dofs, \\\n",
    "                               c_mat=B_p, c_pardofs=eavg_pardofs, \\\n",
    "                               invtype_loc='sparsecholesky', \\\n",
    "                               invtype_glob='masterinverse')\n",
    "    F = B @ A_dp_inv @ B.T\n",
    "    innerdofs = BitArray([len(fes.ParallelDofs().Dof2Proc(k))==0 for k in range(fes.ndof)]) & fes.FreeDofs()\n",
    "    A = a.mat.local_mat\n",
    "    Aiinv = A.Inverse(innerdofs, inverse='sparsecholesky')\n",
    "    scaledA = ScaledMat(A, [0 if primal_dofs[k] else 1.0/(1+len(fes.ParallelDofs().Dof2Proc(k))) for k in range(fes.ndof)])\n",
    "    scaledBT = ScaledMat(B.T, [0 if primal_dofs[k] else 1.0/(1+len(fes.ParallelDofs().Dof2Proc(k))) for k in range(fes.ndof)])\n",
    "    Fhat = B @ scaledA @ (IdentityMatrix() - Aiinv @ A) @ scaledBT\n",
    "    nFhat = B @ scaledA @ (Aiinv @ A - IdentityMatrix()) @ scaledBT\n",
    "    return [A_dp, A_dp_inv, F, Fhat, nFhat, B, scaledA, scaledBT]  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:1] \n",
      "global,  ndof = 576877 , lodofs = 75403\n",
      "avg DOFs per core:  6607.56\n",
      "# of global primal dofs:  399\n",
      "# of global multipliers = : 87867\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "comm = MPI_Init()\n",
    "mesh = load_mesh(nref=1)\n",
    "fes, a, f = setup_space(mesh, order=2)\n",
    "A_dp, A_dp_inv, F, Fhat, nFhat, B, scaledA, scaledBT = setup_FETIDP(fes, a)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting up the block-matrices\n",
    "\n",
    "For setting up the saddle point system and the preconditioner, \n",
    "we can use BlockMatrix.\n",
    "\n",
    "We could implement $\\widehat{M}^{-1}$ more efficiently as\n",
    "$$\n",
    "\\widehat{M}^{-1}\n",
    "=\n",
    "\\left(\n",
    "\\begin{matrix}\n",
    "I & 0 \\\\\n",
    "0 & M_s^{-1}\n",
    "\\end{matrix}\n",
    "\\right)\n",
    "\\cdot\n",
    "\\left(\n",
    "\\begin{matrix}\n",
    "I & 0 \\\\\n",
    "B & I\n",
    "\\end{matrix}\n",
    "\\right)\n",
    "\\cdot\n",
    "\\left(\n",
    "\\begin{matrix}\n",
    "A_{\\scriptscriptstyle DP}^{-1} & 0 \\\\\n",
    "0 & I\n",
    "\\end{matrix}\n",
    "\\right)\n",
    "$$\n",
    "\n",
    "This would mean that we only have to apply $A_{\\scriptscriptstyle DP}^{-1}$ and $M_s^{-1}$ \n",
    "once instead of twice.\n",
    "\n",
    "However, this way we would still need multiple unnecessary vector copies.\n",
    "\n",
    "We can avoid both double applications and annecessary vector copies if we really quickly implement this ourselfs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%px\n",
    "M = BlockMatrix([[A_dp, B.T], \\\n",
    "                 [B, None]])\n",
    "Mhat = BlockMatrix([[A_dp_inv, None], \\\n",
    "                    [Fhat @ B @ A_dp_inv, nFhat]])\n",
    "class Mhat_v2(BaseMatrix):\n",
    "    def __init__(self, Ahat, B, Fhat):\n",
    "        super(Mhat_v2, self).__init__()\n",
    "        self.Ahat = Ahat\n",
    "        self.B = B\n",
    "        self.Fhat = Fhat\n",
    "        self.hv = Fhat.CreateColVector()\n",
    "    def Mult(self, x, y):\n",
    "        y[0].data = self.Ahat * x[0]\n",
    "        self.hv.data = x[1] - self.B * y[0]\n",
    "        y[1].data = self.Fhat * self.hv\n",
    "    def CreateRowVector(self):\n",
    "        return BlockVector([self.Ahat.CreateRowVector(), self.Fhat.CreateRowVector()])\n",
    "    def CreateColVector(self):\n",
    "        return BlockVector([self.Ahat.CreateRowVector(), self.Fhat.CreateRowVector()])\n",
    "Mhat2 = Mhat_v2(A_dp_inv, B, Fhat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There is one small problem. \n",
    "We have to use the C++-side GMRES, simply because that is currently \n",
    "the only one available to us form NGSolve.\n",
    "\n",
    "BlockMatrix does not play nice with GMRES, because BlockMatrix works with\n",
    "BlockVectors, and GMRES expects standard NGSolve (Parallel-)Vectors.\n",
    "\n",
    "\"LinMat\" is a *temporary* simple workaround, that just copies between BlockVectors and\n",
    "normal \"linear\" vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%px\n",
    "M_lin = LinMat(M, [A_dp.col_pardofs, B.col_pardofs])\n",
    "Mhat_lin = LinMat(Mhat, [A_dp.col_pardofs, B.col_pardofs])\n",
    "Mhat2_lin = LinMat(Mhat2, [A_dp.col_pardofs, B.col_pardofs])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One last annoyance is that jupyter-notebooks do not capture\n",
    "C++ stdout, so we lose the output during the iteration.\n",
    "\n",
    "**this only affects us inside the notebooks**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:1] \n",
      "took 13 steps for v1\n",
      "time solve v1:  1.1863238600781187\n",
      "dofs per proc and second v1:  4862.727788025885\n",
      "\n",
      "took 13 steps for v2\n",
      "time solve v2:  0.5696066429372877\n",
      "dofs per proc and second v2:  10127.63820704796\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "sol = M_lin.CreateRowVector()\n",
    "rhs = sol.CreateVector()\n",
    "rhs[:] = 0.0\n",
    "rhs[0:fes.ndof] = f.vec.local_vec\n",
    "rhs.SetParallelStatus(f.vec.GetParallelStatus())\n",
    "ngsglobals.msg_level = 3\n",
    "t1 = -comm.WTime()\n",
    "gmr = GMRESSolver(mat=M_lin, pre=Mhat_lin, maxsteps=100,\\\n",
    "                  precision=1e-6, printrates=True)\n",
    "sol.data = gmr * rhs\n",
    "t1 += comm.WTime()\n",
    "nsteps1 = gmr.GetSteps()\n",
    "t2 = -comm.WTime()\n",
    "gmr = GMRESSolver(mat=M_lin, pre=Mhat2_lin, maxsteps=100,\\\n",
    "                  precision=1e-6, printrates=True)\n",
    "sol.data = gmr * rhs\n",
    "t2 += comm.WTime()\n",
    "nsteps2 = gmr.GetSteps()\n",
    "if comm.rank==0:\n",
    "    print('\\ntook', nsteps1, 'steps for v1')\n",
    "    print('time solve v1: ', t1)\n",
    "    print('dofs per proc and second v1: ', fes.ndofglobal / ( t1 * comm.size))\n",
    "    print('\\ntook', nsteps2, 'steps for v2')\n",
    "    print('time solve v2: ', t2)\n",
    "    print('dofs per proc and second v2: ', fes.ndofglobal / ( t2 * comm.size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dummy - AllReduce :   1.266655683517456\n",
      "dummy - AllReduce :   0.8690826892852783\n"
     ]
    }
   ],
   "source": [
    "%%px --target 1\n",
    "for t in sorted(filter(lambda t:t['time']>0.5, Timers()), key=lambda t:t['time'], reverse=True):\n",
    "    print(t['name'], ':  ', t['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:12] \n",
      "timers from rank  13 :\n",
      "SparseCholesky<d,d,d>::MultAdd :   1.9462761878967285\n",
      "SparseCholesky<d,d,d>::MultAdd fac1 :   1.1591949462890625\n",
      "SparseCholesky<d,d,d>::MultAdd fac2 :   0.7608258724212646\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "t_chol = filter(lambda t: t['name'] == 'SparseCholesky<d,d,d>::MultAdd', Timers()).__next__()\n",
    "maxt = comm.Max(t_chol['time']) \n",
    "if t_chol['time'] == maxt:\n",
    "    print('timers from rank ', comm.rank, ':')\n",
    "    for t in sorted(filter(lambda t:t['time']>min(0.3*maxt, 0.5), Timers()), key=lambda t:t['time'], reverse=True):\n",
    "        print(t['name'], ':  ', t['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:0] max timer rank 14 :  SparseCholesky<d,d,d>::MultAdd   1.7368290424346924\n",
      "[stdout:1] max timer rank 0 :  dummy - AllReduce   1.266655683517456\n",
      "[stdout:2] max timer rank 18 :  SparseCholesky<d,d,d>::MultAdd   1.413224220275879\n",
      "[stdout:3] max timer rank 8 :  SparseCholesky<d,d,d>::MultAdd   1.675055980682373\n",
      "[stdout:4] max timer rank 17 :  SparseCholesky<d,d,d>::MultAdd   1.814605951309204\n",
      "[stdout:5] max timer rank 16 :  SparseCholesky<d,d,d>::MultAdd   1.579732894897461\n",
      "[stdout:6] max timer rank 2 :  SparseCholesky<d,d,d>::MultAdd   0.8424544334411621\n",
      "[stdout:7] max timer rank 12 :  SparseCholesky<d,d,d>::MultAdd   1.098663091659546\n",
      "[stdout:8] max timer rank 10 :  dummy - AllReduce   0.6785235404968262\n",
      "[stdout:9] max timer rank 6 :  SparseCholesky<d,d,d>::MultAdd   1.760845422744751\n",
      "[stdout:10] max timer rank 1 :  SparseCholesky<d,d,d>::MultAdd   1.5053730010986328\n",
      "[stdout:11] max timer rank 19 :  SparseCholesky<d,d,d>::MultAdd   1.6160402297973633\n",
      "[stdout:12] max timer rank 13 :  SparseCholesky<d,d,d>::MultAdd   1.9462761878967285\n",
      "[stdout:13] max timer rank 15 :  SparseCholesky<d,d,d>::MultAdd   1.866727352142334\n",
      "[stdout:14] max timer rank 4 :  SparseCholesky<d,d,d>::MultAdd   1.491440773010254\n",
      "[stdout:15] max timer rank 5 :  SparseCholesky<d,d,d>::MultAdd   1.719233751296997\n",
      "[stdout:16] max timer rank 3 :  SparseCholesky<d,d,d>::MultAdd   1.3392870426177979\n",
      "[stdout:17] max timer rank 11 :  SparseCholesky<d,d,d>::MultAdd   0.9613127708435059\n",
      "[stdout:18] max timer rank 7 :  SparseCholesky<d,d,d>::MultAdd   1.774916648864746\n",
      "[stdout:19] max timer rank 9 :  SparseCholesky<d,d,d>::MultAdd   1.0627970695495605\n"
     ]
    }
   ],
   "source": [
    "%%px --targets 0:20\n",
    "maxt = max(Timers(), key=lambda t:t['time'])\n",
    "if maxt['time']>min(0.3*maxt['time'], 0.5):\n",
    "    print('max timer rank', comm.rank, ': ', maxt['name'], ' ', maxt['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "stop_cluster()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
